// Copyright 2022 The Liquigo Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package translate

import (
	"errors"
	"fmt"
	"strconv"
	"strings"

	config "gitee.com/west0207/liquigo/core/config"
	db "gitee.com/west0207/liquigo/core/db"
	. "gitee.com/west0207/liquigo/core/log"
	utils "gitee.com/west0207/liquigo/core/utils"

	etree "gitee.com/west0207/etree"
)

// 获取changeSet的preConditionsEle标签的sql语句
// dbmsName *string 数据库类型名称，例如：mysql
// dbmsVersion *string 数据库系统版本，例如：5
// preConditionsEle *etree.Element changeSet的注释XML元素
// func GetPreConditionsSql(dbmsName *string, dbmsVersion *string, preConditionsEle *etree.Element) (string, error) {
// 	return utils.LF, nil
// }

// CheckPreConditions evaluates the child conditions of a changeSet's
// <preConditions> element, for example:
//
//	<preConditions onFail="MARK_RAN" onError="HALT" onFailMessage="Fail Message" onErrorMessage="Error Message">
//		<not>
//			<indexExists indexName="idx_test_property_person_name" />
//		</not>
//	</preConditions>
//
// The direct children are combined with AND and evaluated in document order;
// evaluation stops at the first child that is false or fails with an error.
//
// Return values:
//   - (true, nil): all conditions hold, so the changeSet should be executed.
//   - (false, nil): a condition does not hold; the caller applies onFail.
//   - (false, err): a condition could not be evaluated; the caller applies onError.
//
// onFail and onError default to utils.PC_HALT when absent. With HALT, this
// function panics instead of returning. onFailMessage and onErrorMessage, when
// set, are logged before the corresponding outcome is handled.
func CheckPreConditions(dbConfig *config.DB, preConditionsEle *etree.Element) (bool, error) {
	var (
		onFail         = preConditionsEle.SelectAttrValue("onFail", utils.PC_HALT)
		onError        = preConditionsEle.SelectAttrValue("onError", utils.PC_HALT)
		onFailMessage  = preConditionsEle.SelectAttrValue("onFailMessage", utils.EMPTY)
		onErrorMessage = preConditionsEle.SelectAttrValue("onErrorMessage", utils.EMPTY)
	)

	for _, condition := range preConditionsEle.ChildElements() {
		passed, err := checkCondition(dbConfig, utils.LO_AND, condition)
		if err != nil {
			if onErrorMessage != utils.EMPTY {
				Sug.Errorf("preConditions onErrorMessage: %v", onErrorMessage)
			}
			if onError == utils.PC_HALT {
				Sug.Errorf("preConditions error and halt: %v", condition.Tag)
				panic(err)
			}
			return passed, err
		}
		if passed {
			continue
		}

		// The children are ANDed, so the first false one settles the outcome.
		if onFailMessage != utils.EMPTY {
			Sug.Errorf("preConditions onFailMessage: %v", onFailMessage)
		}
		if onFail == utils.PC_HALT {
			panic(errors.New("preConditions failed and halt: " + condition.Tag))
		}
		return false, nil
	}
	return true, nil
}

// checkCondition evaluates one child element of <preConditions>, or of a
// nested logical element.
//
// <not>, <and> and <or> are handed to checkLogicalCondition together with their
// child elements. Every other tag (<columnExists>, <tableExists>, <viewExists>,
// <indexExists>, <sqlCheck>, ...) is handed to checkEntityCondition.
// The logicalOperator parameter is not used by this function.
func checkCondition(dbConfig *config.DB, logicalOperator string, conditionElem *etree.Element) (bool, error) {
	var operator string
	switch conditionElem.Tag {
	case "not":
		operator = utils.LO_NOT
	case "and":
		operator = utils.LO_AND
	case "or":
		operator = utils.LO_OR
	default:
		// Entity tags such as <columnExists>/<tableExists>/...
		return checkEntityCondition(dbConfig, utils.LO_OR, conditionElem)
	}
	return checkLogicalCondition(dbConfig, operator, conditionElem.ChildElements())
}

// checkLogicalCondition evaluates the children of a <not>, <and> or <or>
// element and combines their results according to logicalOperator:
//
//   - utils.LO_AND: true when every child is true (true when there are no children).
//   - utils.LO_OR:  true when at least one child is true (false when there are no children).
//   - utils.LO_NOT: the negation of the AND result, i.e. true as soon as one child is false
//     (false when there are no children).
//
// Children are evaluated in document order, and evaluation stops as soon as the
// result is decided. This matches how CheckPreConditions treats its own direct
// children, and avoids running catalog queries whose outcome cannot matter.
// An error from an evaluated child is returned immediately as (false, err).
// An unsupported operator is logged and yields (false, nil).
func checkLogicalCondition(dbConfig *config.DB, logicalOperator string, childElems []*etree.Element) (bool, error) {
	switch logicalOperator {
	case utils.LO_AND, utils.LO_NOT:
		allTrue := true
		for _, child := range childElems {
			ok, err := checkCondition(dbConfig, logicalOperator, child)
			if err != nil {
				return false, err
			}
			if !ok {
				// One false child decides both AND (false) and NOT (true).
				allTrue = false
				break
			}
		}
		if logicalOperator == utils.LO_NOT {
			return !allTrue, nil
		}
		return allTrue, nil
	case utils.LO_OR:
		for _, child := range childElems {
			ok, err := checkCondition(dbConfig, logicalOperator, child)
			if err != nil {
				return false, err
			}
			if ok {
				return true, nil
			}
		}
		return false, nil
	default:
		Sug.Errorf("the logical operator is not supported: %v", logicalOperator)
		return false, nil
	}
}

// checkEntityCondition evaluates a single entity precondition element:
//
//	<columnExists tableName="t_user" columnName="username" />
//	<tableExists tableName="t_user" />
//	<viewExists viewName="v_user_view" />
//	<indexExists indexName="idx_user_name" />
//	<sqlCheck expectedResult="1">select count(1) from pg_tables where tablename = 't_user'</sqlCheck>
//
// Each name attribute is read (empty when absent) and passed through
// SetPropertyValue before use. Column, table and view checks build an
// existence query and run it with runCheckExistSql. Index checks go to
// checkIndexExistsCondition, and <sqlCheck> goes to checkSqlCheckCondition.
// Any other tag yields an error. The logicalOperator parameter is not used here.
func checkEntityCondition(dbConfig *config.DB, logicalOperator string, entityElem *etree.Element) (bool, error) {
	dbmsName := dbConfig.DbmsName

	// resolved reads an attribute and lets SetPropertyValue rewrite it in place.
	resolved := func(attrName string) string {
		value := entityElem.SelectAttrValue(attrName, utils.EMPTY)
		SetPropertyValue(&dbmsName, &value)
		return value
	}

	var (
		checkSql string
		err      error
	)
	switch tag := entityElem.Tag; tag {
	case "columnExists":
		tableName := resolved("tableName")
		columnName := resolved("columnName")
		checkSql, err = getColumnExistsSql(dbConfig, &tableName, &columnName)
	case "tableExists":
		tableName := resolved("tableName")
		checkSql, err = getTableExistsSql(dbConfig, &tableName)
	case "viewExists":
		viewName := resolved("viewName")
		checkSql, err = getViewExistsSql(dbConfig, &viewName)
	case "indexExists":
		indexName := resolved("indexName")
		return checkIndexExistsCondition(dbConfig, &indexName)
	case "sqlCheck":
		return checkSqlCheckCondition(&dbmsName, entityElem)
	default:
		return false, errors.New("the entity element is not supported: " + tag)
	}

	if err != nil {
		return false, err
	}
	return runCheckExistSql(&checkSql)
}

// getDbnameOfMysql extracts the database name from a go-sql-driver/mysql DSN
// of the form [user[:password]@][net[(addr)]]/dbname[?params], e.g.
//
//	root:Shdb@5727@tcp(192.168.0.231:3016)/test_db?charset=utf8mb4&multiStatements=true
//
// yields "test_db".
//
// Like the driver itself, it uses the LAST '/' in the DSN and stops at the
// first '?' after it. A password may therefore contain '/', '?' or ")/", and
// the "?params" suffix is optional. When the DSN contains no '/' at all, ""
// is returned instead of panicking.
func getDbnameOfMysql(dataSourceName *string) string {
	dsn := *dataSourceName
	slash := strings.LastIndex(dsn, "/")
	if slash < 0 {
		return ""
	}
	dbname := dsn[slash+1:]
	if q := strings.IndexByte(dbname, '?'); q >= 0 {
		dbname = dbname[:q]
	}
	return dbname
}

// getColumnExistsSql builds the SQL that checks whether column *columnName
// exists on table *tableName for the DBMS configured in dbConfig.
//
// Every generated statement is a single "select count(9) ..." catalog lookup;
// a positive count means the column exists. Names are spliced in as quoted
// literals without escaping.
// NOTE(review): a name containing a single quote breaks the statement; names
// presumably come from a trusted changelog — confirm.
//
// An unsupported dbConfig.DbmsName yields an error.
func getColumnExistsSql(dbConfig *config.DB, tableName *string, columnName *string) (string, error) {
	var sql string
	dbmsName := dbConfig.DbmsName

	switch dbmsName {
	case MySQL, MariaDB, TiDB:
		dbname := getDbnameOfMysql(&dbConfig.DataSourceName)
		sql = "select count(9) from information_schema.columns where " +
			"table_schema = '" + dbname + "' and table_name = '" +
			*tableName + "' and column_name = '" + *columnName + "'"
	case PostgreSQL, Kingbase:
		sql = "select count(9) from pg_class a, pg_attribute b" +
			" where a.oid = b.attrelid" +
			" and a.relname = '" + *tableName + "'" +
			" and b.attname = '" + *columnName + "'"
	case Oracle:
		// Names are upper-cased to match how Oracle stores unquoted identifiers.
		sql = "select count(9) from user_tab_columns where table_name = '" +
			strings.ToUpper(*tableName) + "' and column_name = '" + strings.ToUpper(*columnName) + "'"
	case Dameng:
		sql = "select count(9) from user_tab_columns where table_name = '" + *tableName + "' and column_name = '" + *columnName + "'"
	case SQLite:
		// Look the column up with the pragma_table_info table-valued function
		// (SQLite >= 3.16) instead of pattern-matching the stored CREATE TABLE
		// text. A LIKE '% col %' match misses column names that are not
		// surrounded by spaces (e.g. "(id ..." or a tab-indented line), and it
		// treats "_" in a column name as a wildcard. Column names are compared
		// case-insensitively, as SQLite identifiers are.
		sql = "select count(9) from sqlite_master m, pragma_table_info(m.name) p" +
			" where m.type = 'table' and m.name = '" + *tableName + "'" +
			" and p.name = '" + *columnName + "' collate nocase"
	case MsSQLServer:
		sqlTemp := `
			select count(9) from (
				select id from sysobjects where name = '%v' and xtype = 'U'
			) a,
			(
				select name, id parent_obj from syscolumns
				where name = '%v'
			) b
			where a.id = b.parent_obj`
		sql = fmt.Sprintf(sqlTemp, *tableName, *columnName)
	default:
		// Unsupported database type
		return sql, errors.New("unsupported database type: " + dbmsName)
	}
	return sql, nil
}

// getTableExistsSql returns a "select count(9) ..." query that counts the
// tables named *tableName for the DBMS configured in dbConfig; a positive
// count means the table exists. On MySQL/MariaDB/TiDB the lookup is limited to
// the schema named in the DSN and to base tables. On Oracle the name is
// upper-cased. An unsupported DBMS yields ("", error).
func getTableExistsSql(dbConfig *config.DB, tableName *string) (string, error) {
	switch dbms := dbConfig.DbmsName; dbms {
	case MySQL, MariaDB, TiDB:
		schema := getDbnameOfMysql(&dbConfig.DataSourceName)
		return fmt.Sprintf("select count(9) from information_schema.tables where table_schema = '%s'"+
			" and table_type = 'BASE TABLE' and table_name = '%s'", schema, *tableName), nil
	case PostgreSQL, Kingbase:
		return fmt.Sprintf("select count(9) from pg_stat_user_tables where relname = '%s'", *tableName), nil
	case Oracle:
		return fmt.Sprintf("select count(9) from user_tables where table_name = '%s'", strings.ToUpper(*tableName)), nil
	case Dameng:
		return fmt.Sprintf("select count(9) from user_tables where table_name = '%s'", *tableName), nil
	case SQLite:
		return fmt.Sprintf("select count(9) from sqlite_master where type = 'table' and name = '%s'", *tableName), nil
	case MsSQLServer:
		return fmt.Sprintf("select count(9) from sysobjects where name = '%s' and xtype = 'U'", *tableName), nil
	default:
		return "", errors.New("unsupported database type: " + dbms)
	}
}

// getViewExistsSql returns a "select count(9) ..." query that counts the
// views named *viewName for the DBMS configured in dbConfig; a positive count
// means the view exists. On MySQL/MariaDB/TiDB the lookup is limited to the
// schema named in the DSN. On Oracle the name is upper-cased. An unsupported
// DBMS yields ("", error).
func getViewExistsSql(dbConfig *config.DB, viewName *string) (string, error) {
	switch dbms := dbConfig.DbmsName; dbms {
	case MySQL, MariaDB, TiDB:
		schema := getDbnameOfMysql(&dbConfig.DataSourceName)
		return fmt.Sprintf("select count(9) from information_schema.views where table_schema = '%s'"+
			" and table_name = '%s'", schema, *viewName), nil
	case PostgreSQL, Kingbase:
		return fmt.Sprintf("select count(9) from pg_views where viewname = '%s'", *viewName), nil
	case Oracle:
		return fmt.Sprintf("select count(9) from user_views where view_name = '%s'", strings.ToUpper(*viewName)), nil
	case Dameng:
		return fmt.Sprintf("select count(9) from user_views where view_name = '%s'", *viewName), nil
	case SQLite:
		return fmt.Sprintf("select count(9) from sqlite_master where type = 'view' and name = '%s'", *viewName), nil
	case MsSQLServer:
		return fmt.Sprintf("select count(9) from sysobjects where name = '%s' and xtype = 'V'", *viewName), nil
	default:
		return "", errors.New("unsupported database type: " + dbms)
	}
}

// checkIndexExistsCondition reports whether an index named *indexName exists.
// It builds a DBMS-specific "select count(9) ..." catalog query, runs it with
// runCheckIndexExistSql and returns that result. On MySQL/MariaDB/TiDB the
// lookup is limited to the schema named in the DSN. On Oracle the name is
// upper-cased. An unsupported DBMS yields (false, error) without running any
// query.
func checkIndexExistsCondition(dbConfig *config.DB, indexName *string) (bool, error) {
	var query string
	switch dbms := dbConfig.DbmsName; dbms {
	case MySQL, MariaDB, TiDB:
		query = fmt.Sprintf("select count(9) from information_schema.statistics where table_schema = '%s'"+
			" and index_name = '%s'", getDbnameOfMysql(&dbConfig.DataSourceName), *indexName)
	case PostgreSQL, Kingbase:
		query = fmt.Sprintf("select count(9) from pg_stat_user_indexes where indexrelname = '%s'", *indexName)
	case Oracle:
		query = fmt.Sprintf("select count(9) from user_indexes where index_name = '%s'", strings.ToUpper(*indexName))
	case Dameng:
		query = fmt.Sprintf("select count(9) from user_indexes where index_name = '%s'", *indexName)
	case SQLite:
		query = fmt.Sprintf("select count(9) from sqlite_master where type = 'index' and name = '%s'", *indexName)
	case MsSQLServer:
		query = fmt.Sprintf("select count(9) from sys.indexes where name = '%s'", *indexName)
	default:
		return false, errors.New("unsupported database type: " + dbms)
	}
	return runCheckIndexExistSql(&query)
}

// checkSqlCheckCondition evaluates a <sqlCheck> element, e.g.
//
//	<sqlCheck expectedResult="1">select count(1) from pg_tables where tablename = 't_user'</sqlCheck>
//
// The element text is run through runSqlCheckSql, which reads a single integer
// result. The condition holds when that value equals the integer given in
// expectedResult; only integer results are supported. The dbmsName parameter
// is not used.
//
// expectedResult is parsed (after trimming surrounding whitespace) before the
// query is run. A missing or non-integer value is returned as an error rather
// than causing a panic, so CheckPreConditions can apply the onError policy.
func checkSqlCheckCondition(dbmsName *string, sqlCheckElem *etree.Element) (bool, error) {
	expectedResult := strings.TrimSpace(sqlCheckElem.SelectAttrValue("expectedResult", utils.EMPTY))
	expectedInt, errAtoi := strconv.Atoi(expectedResult)
	if errAtoi != nil {
		Sug.Errorf("run strconv.Atoi(%v) errAtoi %v\n", expectedResult, errAtoi)
		return false, fmt.Errorf("sqlCheck expectedResult %q is not an integer: %w", expectedResult, errAtoi)
	}

	sql := sqlCheckElem.Text()
	intValue, err := runSqlCheckSql(&sql)
	if err != nil {
		return false, err
	}
	return intValue == expectedInt, nil
}

// runCheckExistSql runs a column/table/view existence check and reports
// whether the object exists.
//
// getColumnExistsSql, getTableExistsSql and getViewExistsSql all build a
// single-row "select count(9) ..." catalog query. Because that row is returned
// whether or not the object exists, a successful run says nothing by itself.
// The count is therefore scanned, and only a positive count means "exists".
//
// A statement that succeeds but returns no rows (a probe such as
// "select col from t where 1=2") is still treated as proof of existence. Such
// a statement could only run if the objects it references exist.
//
// Errors whose message says the object does not exist are mapped to
// (false, nil); any other error is returned as (false, err).
func runCheckExistSql(checkSql *string) (bool, error) {
	var count int
	has, err := db.Engine.SQL(*checkSql).Get(&count)
	if err != nil {
		Sug.Errorf("run runCheckExistSql err %v\nsql: %v\n", err, *checkSql)
		msg := err.Error()
		if strings.Contains(msg, "doesn't exist") ||
			strings.Contains(msg, "Unknown column") ||
			strings.Contains(msg, "no such column") || // sqlite
			strings.Contains(msg, "不存在") ||
			strings.Contains(msg, "no such table") || // sqlite
			strings.Contains(msg, "标识符无效") || // oracle
			strings.Contains(msg, "无效的") {
			// "Table doesn't exist" / "Unknown column" / "... does not exist" /
			// "invalid table or view name (column name)" / "invalid identifier".
			// TODO: matching on driver error text is fragile; prefer error codes.
			return false, nil
		}
		return false, err
	}
	if !has {
		// Probe-style query: it ran, so the referenced objects exist.
		return true, nil
	}
	return count > 0, nil
}

// runCheckIndexExistSql runs an index existence query, as built by
// checkIndexExistsCondition, and reports whether the index exists.
//
// The query returns a single count, and any positive count means the index
// exists. Requiring exactly 1 would be wrong for two reasons:
//   - MySQL's information_schema.statistics has one row per indexed column, so
//     a composite index is counted once per column.
//   - On several supported databases, index names are unique only per table
//     or schema, so the same name can match more than once.
//
// On a query error the error is logged and returned as (false, err).
func runCheckIndexExistSql(checkSql *string) (bool, error) {
	var cvalue int
	_, err := db.Engine.SQL(*checkSql).Get(&cvalue)
	if err != nil {
		Sug.Errorf("run runCheckIndexExistSql err %v\nsql: %v\n", err, *checkSql)
		return false, err
	}
	return cvalue > 0, nil
}

// runSqlCheckSql executes the query of a <sqlCheck> element and returns its
// single integer result.
//
// The has-row flag reported by Get is ignored, so a query that yields no rows
// presumably comes back as 0, indistinguishable from a real 0. On a query
// error the error is logged and returned together with utils.INT_ZERO.
func runSqlCheckSql(checkSql *string) (int, error) {
	var result int
	if _, err := db.Engine.SQL(*checkSql).Get(&result); err != nil {
		Sug.Errorf("run runSqlCheckSql err %v\nsql: %v\n", err, *checkSql)
		return utils.INT_ZERO, err
	}
	return result, nil
}

// 从results([]map[string][]byte)中获取唯一一条记录的唯一一个字段的整数值
// func getSqlCheckValue(results *[]map[string][]byte) (int, error) {
// 	for _, map1 := range *results {
// 		for _, mvalue := range map1 {
// 			return strconv.Atoi(string(mvalue))
// 		}
// 	}
// 	return utils.INT_ZERO, nil
// }
